/*
 * Copyright © 2021 Dowsure
 * https://www.dowsure.com/
 *
 * All rights reserved.
 */

package com.dowsure.apisaas.util.smalgorithm;

import org.bouncycastle.asn1.*;
import org.bouncycastle.asn1.gm.GMNamedCurves;
import org.bouncycastle.asn1.pkcs.PrivateKeyInfo;
import org.bouncycastle.asn1.sec.ECPrivateKey;
import org.bouncycastle.asn1.x509.SubjectPublicKeyInfo;
import org.bouncycastle.asn1.x9.X9ECParameters;
import org.bouncycastle.crypto.*;
import org.bouncycastle.crypto.engines.SM2Engine;
import org.bouncycastle.crypto.generators.ECKeyPairGenerator;
import org.bouncycastle.crypto.params.*;
import org.bouncycastle.crypto.signers.SM2Signer;
import org.bouncycastle.math.ec.ECConstants;

import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.IOException;
import java.math.BigInteger;
import java.nio.charset.StandardCharsets;
import java.security.SecureRandom;
import java.util.Arrays;

/**
 * SM2 utility class providing encryption/decryption and signing/verification.
 * <p>Encryption and decryption operate on raw bytes. The character set used to turn text into
 * bytes before encryption must match the one used to turn the decrypted bytes back into text.
 * <p>Unless stated otherwise, every method takes raw binary data. Hex- or Base64-encoded input
 * must be decoded before it is passed in.
 * @author
 */
public class SM2Util {

    /** Parameters of the sm2p256v1 curve. */
    private static final X9ECParameters SM2_P256V1 = GMNamedCurves.getByName("sm2p256v1");

    /** Domain parameters for every key operation, built once instead of on each call. */
    private static final ECDomainParameters DOMAIN_PARAMS =
            new ECDomainParameters(SM2_P256V1.getCurve(), SM2_P256V1.getG(), SM2_P256V1.getN());

    /** Byte length of one field element / point coordinate on the curve. */
    private static final int CURVE_LENGTH = (SM2_P256V1.getCurve().getFieldSize() + 7) / 8;

    /** Length of the SM3 digest stored as C3 in the ciphertext. */
    private static final int SM3_DIGEST_LENGTH = 32;

    /** Length of a raw private key scalar. */
    private static final int PRIVATE_KEY_LENGTH = 32;

    /** Length of a raw public key written as X||Y without the point-format prefix byte. */
    private static final int RAW_PUBLIC_KEY_LENGTH = 64;

    /** Prefix byte of an uncompressed elliptic-curve point encoding. */
    private static final byte UNCOMPRESSED_POINT_FLAG = 0x04;

    /** Default user ID used when the caller does not supply one. */
    private static final byte[] DEFAULT_USER_ID = "1234567812345678".getBytes(StandardCharsets.US_ASCII);

    /** Shared randomness source for encryption, signing and key generation. */
    private static final SecureRandom RANDOM = new SecureRandom();

    /**
     * SM2-encrypts data.
     *
     * @param publicKey public key, binary. Accepted forms:
     *                  <ul>
     *                  <li>raw X||Y (64 bytes);</li>
     *                  <li>an encoded point such as 0x04||X||Y (65 bytes);</li>
     *                  <li>a DER SubjectPublicKeyInfo.</li>
     *                  </ul>
     *                  Decode hex/Base64 before passing it in.
     * @param data      plaintext to encrypt
     * @return DER-encoded ciphertext (see {@link #encodeSM2CipherToDER(byte[])})
     * @throws InvalidCipherTextException if the engine rejects the operation
     * @throws IOException                if DER encoding of the ciphertext fails
     */
    public static byte[] encrypt(byte[] publicKey, byte[] data) throws InvalidCipherTextException, IOException {
        SM2Engine engine = new SM2Engine();
        engine.init(true, new ParametersWithRandom(toPublicKeyParameters(publicKey), RANDOM));
        byte[] cipher = engine.processBlock(data, 0, data.length);
        return encodeSM2CipherToDER(cipher);
    }

    /**
     * SM2-decrypts data.
     *
     * @param privateKey    private key, binary. Either the raw scalar (at most 32 bytes) or a DER
     *                      PKCS#8 PrivateKeyInfo. Decode hex/Base64 before passing it in.
     * @param encryptedData DER-encoded ciphertext, binary
     * @return the decrypted plaintext
     * @throws InvalidCipherTextException if the ciphertext does not decrypt under this key
     * @throws IOException                if the key or ciphertext cannot be DER-decoded
     */
    public static byte[] decrypt(byte[] privateKey, byte[] encryptedData) throws InvalidCipherTextException, IOException {
        SM2Engine engine = new SM2Engine();
        engine.init(false, toPrivateKeyParameters(toRawPrivateKey(privateKey)));
        byte[] cipher = decodeDERSM2Cipher(encryptedData);
        return engine.processBlock(cipher, 0, cipher.length);
    }

    /**
     * SM2-signs data with the default user ID {@code 1234567812345678}.
     *
     * @param privateKey private key, binary. Either the raw scalar or a DER PKCS#8 PrivateKeyInfo.
     * @param sourceData data to sign
     * @return DER-encoded signature value
     * @throws CryptoException if signature generation fails
     * @throws IOException     if a DER private key cannot be decoded
     */
    public static byte[] sign(byte[] privateKey, byte[] sourceData) throws CryptoException, IOException {
        return sign(DEFAULT_USER_ID, toRawPrivateKey(privateKey), sourceData);
    }

    /**
     * SM2-signs data.
     *
     * @param userId     user ID; if nothing was agreed, use the default {@code 1234567812345678}.
     *                   When {@code null}, no ID parameter is passed to the signer.
     * @param privateKey raw private key scalar, binary. Unlike {@link #sign(byte[], byte[])},
     *                   DER keys are NOT unwrapped here.
     * @param sourceData data to sign
     * @return DER-encoded signature value
     * @throws CryptoException if signature generation fails
     */
    public static byte[] sign(byte[] userId, byte[] privateKey, byte[] sourceData) throws CryptoException {
        CipherParameters params = new ParametersWithRandom(toPrivateKeyParameters(privateKey), RANDOM);
        if (userId != null) {
            params = new ParametersWithID(params, userId);
        }
        SM2Signer signer = new SM2Signer();
        signer.init(true, params);
        signer.update(sourceData, 0, sourceData.length);
        return signer.generateSignature();
    }

    /**
     * Verifies an SM2 signature using the default user ID {@code 1234567812345678}.
     *
     * @param publicKey  public key, binary (same accepted forms as {@link #encrypt(byte[], byte[])})
     * @param sourceData data whose signature is checked
     * @param signData   signature value
     * @return whether the signature is valid
     */
    public static boolean verifySign(byte[] publicKey, byte[] sourceData, byte[] signData) {
        return verifySign(DEFAULT_USER_ID, publicKey, sourceData, signData);
    }

    /**
     * Verifies an SM2 signature.
     *
     * @param userId     user ID; if nothing was agreed, use the default {@code 1234567812345678}.
     *                   When {@code null}, no ID parameter is passed to the verifier.
     * @param publicKey  public key, binary (same accepted forms as {@link #encrypt(byte[], byte[])})
     * @param sourceData data whose signature is checked
     * @param signData   signature value
     * @return whether the signature is valid
     */
    public static boolean verifySign(byte[] userId, byte[] publicKey, byte[] sourceData, byte[] signData) {
        CipherParameters params = toPublicKeyParameters(publicKey);
        if (userId != null) {
            params = new ParametersWithID(params, userId);
        }
        SM2Signer signer = new SM2Signer();
        signer.init(false, params);
        signer.update(sourceData, 0, sourceData.length);
        return signer.verifySignature(signData);
    }

    /**
     * Extracts the public key bytes from a DER SubjectPublicKeyInfo.
     *
     * @param derData DER-encoded public key, binary
     * @return the contents of the subjectPublicKey BIT STRING (the encoded EC point)
     */
    public static byte[] getPublicKey(byte[] derData) {
        SubjectPublicKeyInfo info = SubjectPublicKeyInfo.getInstance(derData);
        return info.getPublicKeyData().getBytes();
    }

    /**
     * Extracts the private key scalar from a DER PKCS#8 PrivateKeyInfo wrapping an EC private key.
     *
     * @param derData DER-encoded private key, binary
     * @return the scalar, left-padded to 32 bytes, or {@code null} if it does not fit in 32 bytes
     * @throws IOException if the inner private key structure cannot be parsed
     */
    public static byte[] getPrivateKey(byte[] derData) throws IOException {
        PrivateKeyInfo info = PrivateKeyInfo.getInstance(derData);
        ECPrivateKey ecKey = ECPrivateKey.getInstance(info.parsePrivateKey());

        byte[] bytes = ecKey.getKey().toByteArray();
        if (bytes.length == PRIVATE_KEY_LENGTH) {
            return bytes;
        }

        // BigInteger.toByteArray() may add a leading 0x00 sign byte; skip it.
        int start = bytes[0] == 0 ? 1 : 0;
        int count = bytes.length - start;
        if (count > PRIVATE_KEY_LENGTH) {
            return null;
        }

        byte[] padded = new byte[PRIVATE_KEY_LENGTH];
        System.arraycopy(bytes, start, padded, PRIVATE_KEY_LENGTH - count, count);
        return padded;
    }

    /**
     * Generates an SM2 key pair.
     * <p>The private key is a 32-byte scalar. The public key is 64 bytes (X||Y, no 0x04 prefix).
     *
     * @return the generated key pair
     */
    public static SM2KeyPair generateKeyPair() {
        ECKeyPairGenerator generator = new ECKeyPairGenerator();
        generator.init(new ECKeyGenerationParameters(DOMAIN_PARAMS, RANDOM));

        // SM2 (GM/T 0003) requires d in [1, n-2]. With d = n-1, (1 + d) is not invertible mod n,
        // so signing with that key would fail.
        BigInteger maxD = SM2_P256V1.getN().subtract(ECConstants.ONE).subtract(ECConstants.ONE);

        ECPrivateKeyParameters privParams;
        ECPublicKeyParameters pubParams;
        do {
            AsymmetricCipherKeyPair keyPair = generator.generateKeyPair();
            privParams = (ECPrivateKeyParameters) keyPair.getPrivate();
            pubParams = (ECPublicKeyParameters) keyPair.getPublic();
        } while (privParams.getD().signum() <= 0 || privParams.getD().compareTo(maxD) > 0);

        byte[] privKey = formatBigNum(privParams.getD(), PRIVATE_KEY_LENGTH);
        // getEncoded(false) yields 0x04||X||Y with fixed-width coordinates; drop the prefix byte.
        byte[] encodedQ = pubParams.getQ().normalize().getEncoded(false);
        byte[] pubKey = Arrays.copyOfRange(encodedQ, 1, 1 + RAW_PUBLIC_KEY_LENGTH);
        return new SM2KeyPair(privKey, pubKey);
    }

    /**
     * Converts a C1C2C3 ciphertext into DER, using the GM/T 0009 order (x, y, C3, C2).
     *
     * @param cipher C1C2C3 ciphertext. C1 is 0x04||x||y; C3 is the 32-byte SM3 digest.
     * @return DER-encoded SM2 ciphertext
     * @throws IOException              if DER encoding fails
     * @throws IllegalArgumentException if {@code cipher} is too short to hold C1 and C3
     */
    public static byte[] encodeSM2CipherToDER(byte[] cipher) throws IOException {
        int c1End = 1 + 2 * CURVE_LENGTH; // prefix byte + x + y
        int c2Length = cipher.length - c1End - SM3_DIGEST_LENGTH;
        if (c2Length < 0) {
            throw new IllegalArgumentException("SM2 ciphertext too short: " + cipher.length + " bytes");
        }

        byte[] c1x = Arrays.copyOfRange(cipher, 1, 1 + CURVE_LENGTH);
        byte[] c1y = Arrays.copyOfRange(cipher, 1 + CURVE_LENGTH, c1End);
        byte[] c2 = Arrays.copyOfRange(cipher, c1End, c1End + c2Length);
        byte[] c3 = Arrays.copyOfRange(cipher, c1End + c2Length, cipher.length);

        DERSequence sequence = new DERSequence(new ASN1Encodable[] {
                new ASN1Integer(new BigInteger(1, c1x)),
                new ASN1Integer(new BigInteger(1, c1y)),
                new DEROctetString(c3),
                new DEROctetString(c2)
        });
        return sequence.getEncoded(ASN1Encoding.DER);
    }

    /**
     * Decodes a DER SM2 ciphertext (per GM/T 0009-2012, "SM2 cryptography algorithm application
     * specification") into the C1C2C3 format.
     *
     * @param derCipher DER-encoded SM2 ciphertext: SEQUENCE { x INTEGER, y INTEGER, C3 OCTET STRING, C2 OCTET STRING }
     * @return C1C2C3 ciphertext, where C1 is 0x04||x||y with fixed-width coordinates
     * @throws IOException if the input is not a four-element ASN.1 SEQUENCE or cannot be read
     */
    public static byte[] decodeDERSM2Cipher(byte[] derCipher) throws IOException {
        ASN1Sequence sequence;
        try (ASN1InputStream in = new ASN1InputStream(new ByteArrayInputStream(derCipher))) {
            ASN1Primitive obj = in.readObject();
            if (!(obj instanceof ASN1Sequence)) {
                throw new IOException("SM2 ciphertext is not an ASN.1 SEQUENCE");
            }
            sequence = (ASN1Sequence) obj;
        }
        if (sequence.size() != 4) {
            throw new IOException("SM2 ciphertext SEQUENCE must have 4 elements, found " + sequence.size());
        }

        byte[] c1x = formatBigNum(ASN1Integer.getInstance(sequence.getObjectAt(0)).getValue(), CURVE_LENGTH);
        byte[] c1y = formatBigNum(ASN1Integer.getInstance(sequence.getObjectAt(1)).getValue(), CURVE_LENGTH);
        byte[] c3 = ASN1OctetString.getInstance(sequence.getObjectAt(2)).getOctets();
        byte[] c2 = ASN1OctetString.getInstance(sequence.getObjectAt(3)).getOctets();

        byte[] cipherText = new byte[1 + 2 * CURVE_LENGTH + c2.length + c3.length];
        int pos = 0;
        cipherText[pos++] = UNCOMPRESSED_POINT_FLAG;
        System.arraycopy(c1x, 0, cipherText, pos, CURVE_LENGTH);
        pos += CURVE_LENGTH;
        System.arraycopy(c1y, 0, cipherText, pos, CURVE_LENGTH);
        pos += CURVE_LENGTH;
        System.arraycopy(c2, 0, cipherText, pos, c2.length);
        pos += c2.length;
        System.arraycopy(c3, 0, cipherText, pos, c3.length);
        return cipherText;
    }

    /**
     * Converts a public key in any accepted form into BouncyCastle parameters.
     * <p>A 65-byte input starting with 0x04 is a raw uncompressed point. Any other input longer
     * than 64 bytes is treated as a DER SubjectPublicKeyInfo. A 64-byte X||Y input gets the 0x04
     * prefix added.
     */
    private static ECPublicKeyParameters toPublicKeyParameters(byte[] publicKey) {
        byte[] encoded = publicKey;
        boolean rawUncompressed = encoded.length == RAW_PUBLIC_KEY_LENGTH + 1
                && encoded[0] == UNCOMPRESSED_POINT_FLAG;
        if (encoded.length > RAW_PUBLIC_KEY_LENGTH && !rawUncompressed) {
            encoded = getPublicKey(encoded);
        }
        if (encoded.length == RAW_PUBLIC_KEY_LENGTH) {
            byte[] withFlag = new byte[RAW_PUBLIC_KEY_LENGTH + 1];
            withFlag[0] = UNCOMPRESSED_POINT_FLAG;
            System.arraycopy(encoded, 0, withFlag, 1, RAW_PUBLIC_KEY_LENGTH);
            encoded = withFlag;
        }
        return new ECPublicKeyParameters(SM2_P256V1.getCurve().decodePoint(encoded), DOMAIN_PARAMS);
    }

    /**
     * Returns the raw private key scalar.
     * <p>Inputs longer than 32 bytes are unwrapped as DER PKCS#8. Fails loudly instead of
     * passing {@code null} on to key construction.
     */
    private static byte[] toRawPrivateKey(byte[] privateKey) throws IOException {
        if (privateKey.length <= PRIVATE_KEY_LENGTH) {
            return privateKey;
        }
        byte[] raw = getPrivateKey(privateKey);
        if (raw == null) {
            throw new IllegalArgumentException("DER private key scalar does not fit in "
                    + PRIVATE_KEY_LENGTH + " bytes");
        }
        return raw;
    }

    /** Wraps a raw (unsigned, big-endian) private key scalar as BouncyCastle parameters. */
    private static ECPrivateKeyParameters toPrivateKeyParameters(byte[] rawPrivateKey) {
        return new ECPrivateKeyParameters(new BigInteger(1, rawPrivateKey), DOMAIN_PARAMS);
    }

    /**
     * Formats a BigInteger to a fixed length.
     * <p>The length of {@code toByteArray()} varies, so the result is left-padded with zeros, or
     * its leading bytes are dropped, to reach exactly {@code needLength} bytes.
     *
     * @param bg         the number to format
     * @param needLength required output length
     * @return big-endian bytes of exactly {@code needLength} length
     */
    private static byte[] formatBigNum(BigInteger bg, int needLength) {
        byte[] bytes = bg.toByteArray();
        if (bytes.length == needLength) {
            return bytes;
        }
        byte[] out = new byte[needLength];
        if (bytes.length > needLength) {
            // Keeps the low-order bytes. For non-negative values that fit in needLength bytes, the
            // dropped prefix is only the 0x00 sign byte.
            System.arraycopy(bytes, bytes.length - needLength, out, 0, needLength);
        } else {
            System.arraycopy(bytes, 0, out, needLength - bytes.length, bytes.length);
        }
        return out;
    }

    /**
     * Generates an SM2 key pair and writes it under {@code cert/<merName>/}.
     * <p>The private key goes to {@code <merName>.pvk} and the public key to {@code <merName>.puk}.
     * Both keys are then printed in Base64.
     * <p>NOTE(review): this prints the PRIVATE key to stdout, and {@code merName} is used in the
     * path without sanitization. Only call it from trusted tooling.
     *
     * @param merName merchant name, used as directory and file base name
     * @throws Exception if writing the key files fails
     */
    public static void KeyGen(String merName) throws Exception {
        SM2KeyPair keyPair = SM2Util.generateKeyPair();

        File dir = new File("cert/" + merName);
        if (!dir.exists()) {
            System.out.println("是否生成目录：" + dir.mkdirs());
        }

        String basePath = "cert/" + merName + File.separator + merName;
        FileHelper.write(basePath + ".pvk", keyPair.getPriByte());
        FileHelper.write(basePath + ".puk", keyPair.getPubByte());

        System.out.println("钥生成成功：" + merName);
        System.out.println("公钥Base64：" + keyPair.getBase64PubKey());
        System.out.println("私钥Base64：" + keyPair.getBase64PriKey());
    }
}
